#!/usr/bin/python

# Copyright (c) 2009 Tom Pinckney
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation
# files (the "Software"), to deal in the Software without
# restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following
# conditions:
#
#     The above copyright notice and this permission notice shall be
#     included in all copies or substantial portions of the Software.
#
#     THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
#     EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
#     OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
#     NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
#     HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
#     WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
#     FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
#     OTHER DEALINGS IN THE SOFTWARE.

import sys
import os
import socket
import struct
import ConfigParser
import signal
import getopt

sys.path.append('/usr/lib/pymds')

from utils import *

class DnsError(Exception):
    """Raised by the request parser for a query it cannot or will not accept."""

def serve(udps):
    """Answer DNS queries arriving on ``udps`` forever.

    Each datagram is parsed with ``parse_request``.  The lower-cased question
    is compared with the ``domain`` of every entry in the module-level
    ``config_files``; the first entry whose domain is a suffix of the question
    has its ``source`` plugin (and then its optional ``filters``) produce the
    answer.  The reply is built with ``format_response`` and sent back to the
    sender.  Debug traces are printed to stdout for every handled query.

    Per-request error handling:
      * the packet cannot be parsed: no transaction id is known, so the
        packet is dropped without a reply;
      * no configured domain matches: reply with rcode 5 (REFUSED);
      * any other failure while building the answer: reply with rcode 2
        (SERVFAIL).

    Args:
        udps: a bound datagram socket (needs ``recvfrom`` and ``sendto``).

    This function never returns.
    """
    # The additional section is always sent empty; the glue records computed
    # per zone below are not used.
    ar_resource_records = []
    while True:
        try:
            req_pkt, src_addr = udps.recvfrom(512)   # max UDP DNS pkt size
        except socket.error:
            continue

        try:
            qid, question, qtype, qclass = parse_request(req_pkt)
        except Exception:
            # Without a parsed transaction id there is nothing to echo back,
            # so a malformed packet is silently dropped.
            continue

        # rcode to send if anything below fails: 2 unless a more specific
        # code is chosen first.
        exception_rcode = 2
        try:
            question = [label.lower() for label in question]
            found = False
            for config in config_files.values():
                domain = config['domain']
                if question[1:] == domain:
                    # Exactly one label in front of the zone: pass it as a string.
                    query = question[0]
                elif question == domain:
                    query = ''
                elif question[-len(domain):] == domain:
                    # Several labels in front of the zone: pass them as a list.
                    query = question[:-len(domain)]
                else:
                    continue
                ns_resource_records, _ignored_ar_records = compute_name_server_resources(config)
                # Keep a real (shallow) copy of the NS data so that a filter
                # editing the list in place is still detected below.
                original_ns_resource_records = list(ns_resource_records)

                # Perform the query against the backend plugin
                print("Checking with  %s" % (query,))
                rcode, an_resource_records = config['source'].get_response(query, domain, qtype, qclass, src_addr)

                print("Got from get_response(): %s / %s" % (rcode, an_resource_records))
                # Run it against any filters we have available...
                if rcode == 0 and 'filters' in config:
                    for f in config['filters']:
                        print("Running filter! %s" % (query,))
                        an_resource_records, ns_resource_records = f.filter(query, domain, qtype, qclass, src_addr, an_resource_records, ns_resource_records)
                        print("Filter returned %s/%s" % (an_resource_records, ns_resource_records))

                # If there's no error and no responses, wipe NS information
                # (or, if it's an NS query, move it to the answer section)
                # If we don't do this, it'd be considered a referral instead of an answer...
                if rcode == 0 and len(an_resource_records) == 0 and ns_resource_records == original_ns_resource_records:
                    if qtype == 2:
                        an_resource_records = ns_resource_records
                    ns_resource_records = []

                print("About to send back:\nAN: %s\nNS: %s\nAR: %s\n" % (an_resource_records, ns_resource_records, ar_resource_records))
                resp_pkt = format_response(qid, question, qtype, qclass, rcode, an_resource_records, ns_resource_records, ar_resource_records)
                found = True
                break
            if not found:
                exception_rcode = 5
                raise Exception("query is not for our domain: %s" % ".".join(question))
        except Exception:
            # The request parsed, so a transaction id (possibly 0) is always
            # available here and an error reply can always be sent.
            resp_pkt = format_response(qid, question, qtype, qclass, exception_rcode, [], [], [])
        udps.sendto(resp_pkt, src_addr)

def compute_name_server_resources(config):
    """Build the name-server resource records for one zone configuration.

    Args:
        config: zone settings dict.  ``config['name_servers']`` is optional;
            when present it is an iterable of ``(name_labels, ip, ttl)``
            tuples where ``ip`` is an unsigned 32-bit integer.
            ``config['domain']`` is read only when name servers exist.

    Returns:
        A ``(ns, ar)`` tuple of lists of resource dicts.  ``ns`` holds one
        type-2 (NS) record per name server, owned by the zone's domain;
        ``ar`` holds the matching type-1 (A) record carrying the packed
        address.  Both lists are empty when no name servers are configured.
    """
    ns = []
    ar = []

    # Allow there to be zero nameservers defined...
    if 'name_servers' not in config:
        return ns, ar

    for name_server, ip, ttl in config['name_servers']:
        ns.append({
            'qtype': 2,
            'qclass': 1,
            'ttl': ttl,
            'rdata': labels2str(name_server),
            'question': config['domain'],
        })
        # NOTE(review): the address record carries no 'question' key, so
        # format_response() would label it with the queried name rather than
        # the name server's own name - confirm before using these records.
        ar.append({'qtype': 1, 'qclass': 1, 'ttl': ttl, 'rdata': struct.pack("!I", ip)})
    return ns, ar
        
def parse_request(packet):
    """Parse the header and first question of a DNS query packet.

    Args:
        packet: the raw datagram as a byte string.

    Returns:
        ``(qid, labels, qtype, qclass)`` where ``labels`` is the list of
        name labels in the order they appear in the packet.

    Raises:
        DnsError: if the packet is truncated, is not a standard query
            (QR bit set, non-zero opcode or no question), contains a label
            length with either of its top two bits set (compressed names are
            not accepted), or asks for a class other than 1 (IN).
    """
    hdr_len = 12
    if len(packet) < hdr_len:
        raise DnsError("Truncated header")
    qid, flags, qdcount, _, _, _ = struct.unpack('!HHHHHH', packet[:hdr_len])
    qr = (flags >> 15) & 0x1
    opcode = (flags >> 11) & 0xf
    if qr != 0 or opcode != 0 or qdcount == 0:
        raise DnsError("Invalid query")

    body = packet[hdr_len:]
    labels = []
    offset = 0
    while True:
        if offset >= len(body):
            raise DnsError("Truncated question name")
        label_len, = struct.unpack('!B', body[offset:offset+1])
        offset += 1
        if label_len & 0xc0:
            raise DnsError("Invalid label length %d" % label_len)
        if label_len == 0:
            # Zero-length label terminates the name.
            break
        label = body[offset:offset+label_len]
        if len(label) != label_len:
            raise DnsError("Truncated label")
        offset += label_len
        labels.append(label)

    if len(body) < offset + 4:
        raise DnsError("Truncated question")
    qtype, qclass = struct.unpack("!HH", body[offset:offset+4])
    if qclass != 1:
        raise DnsError("Invalid class: %d" % qclass)
    return (qid, labels, qtype, qclass)

def format_response(qid, question, qtype, qclass, rcode, an_resource_records, ns_resource_records, ar_resource_records):
    """Assemble a complete DNS response packet.

    Args:
        qid: transaction id to echo back.
        question: list of labels of the queried name.
        qtype, qclass: type and class from the query, echoed in the
            question section.
        rcode: response code placed in the header.
        an_resource_records: answer records (always included).
        ns_resource_records: authority records (included only if rcode is 0).
        ar_resource_records: additional records (included only if rcode is 0).

    Each record is a dict with 'qtype', 'qclass', 'ttl' and 'rdata'; an
    optional 'question' entry overrides the owner name, which otherwise
    defaults to ``question``.

    Returns:
        The encoded packet: header, question section, then the records.
    """
    resources = list(an_resource_records)
    num_an_resources = len(an_resource_records)
    num_ns_resources = num_ar_resources = 0
    if rcode == 0:
        # Authority/additional data only accompanies successful responses.
        resources.extend(ns_resource_records)
        resources.extend(ar_resource_records)
        num_ns_resources = len(ns_resource_records)
        num_ar_resources = len(ar_resource_records)
    pkt = format_header(qid, rcode, num_an_resources, num_ns_resources, num_ar_resources)
    pkt += format_question(question, qtype, qclass)
    for resource in resources:
        if 'question' in resource:
            pkt += format_resource(resource, resource['question'])
        else:
            pkt += format_resource(resource, question)
    return pkt

def format_header(qid, rcode, ancount, nscount, arcount):
    """Pack the 12-byte header of a response.

    The flags word has bit 15 (QR, "this is a response") and bit 10
    (AA, "authoritative answer") set, plus the low four bits of ``rcode``.
    The question count is always 1.
    """
    qr_bit = 1 << 15
    aa_bit = 1 << 10
    flags = qr_bit | aa_bit | (rcode & 0xf)
    qdcount = 1
    return struct.pack("!HHHHHH", qid, flags, qdcount, ancount, nscount, arcount)

def format_question(question, qtype, qclass):
    """Encode the question section: the name followed by type and class.

    The name is encoded by ``labels2str`` (imported from ``utils``);
    ``qtype`` and ``qclass`` follow as two big-endian unsigned 16-bit values.
    """
    encoded_name = labels2str(question)
    type_and_class = struct.pack("!HH", qtype, qclass)
    return encoded_name + type_and_class

def format_resource(resource, question):
    """Encode one resource record owned by the name ``question``.

    Layout: the name as encoded by ``labels2str``, then type, class, TTL and
    rdata length packed big-endian (16/16/32/16 bits), then the raw rdata.

    Args:
        resource: dict providing 'qtype', 'qclass', 'ttl' and 'rdata'.
        question: list of labels naming the record's owner.
    """
    owner_name = labels2str(question)
    fixed_fields = struct.pack(
        "!HHIH",
        resource['qtype'],
        resource['qclass'],
        resource['ttl'],
        len(resource['rdata']),
    )
    return '' + owner_name + fixed_fields + resource['rdata']

def read_config():
    """(Re)load every configuration file named in ``config_files``.

    For each file name that is a key of the module-level ``config_files``
    dict, the ``[default]`` section is parsed into a settings dict:

      * ``domain``       -> ``config['domain']``: the name split on ".".
      * ``name servers`` -> ``config['name_servers']``: list of
        ``(name_labels, ip_as_int, ttl)`` built from colon-separated
        ``name:ip:ttl`` triples.
      * ``source``       -> ``config['source']``: ``module:arg:...``; the
        module is imported and its ``Source(*args)`` instantiated.
      * ``filters``      -> ``config['filters']``: whitespace-separated
        ``module:arg:...`` items, each instantiated via ``Filter(*args)``.

    The new settings replace the entry in ``config_files`` only once the
    whole file has been processed, so a failing reload never leaves a
    partially populated configuration in place.

    Any problem (unreadable file, unknown parameter, malformed name servers,
    missing ``domain`` or ``source``) is reported through ``die``.
    """
    for config_file in list(config_files):
        config = {}
        config_parser = ConfigParser.SafeConfigParser()
        try:
            config_parser.read(config_file)
            config_values = config_parser.items("default")
        except Exception:
            die("Error reading config file %s\n" % config_file)

        for var, value in config_values:
            if var == "domain":
                config['domain'] = value.split(".")
            elif var == "name servers":
                config['name_servers'] = []
                split_name_servers = value.split(":")
                # Entries come in groups of three: name, ip, ttl.
                if len(split_name_servers) % 3 != 0:
                    die("malformed name servers in conf file %s\n" % config_file)
                for i in range(0, len(split_name_servers), 3):
                    server, ip, ttl = split_name_servers[i:i+3]
                    config['name_servers'].append((server.split("."), ipstr2int(ip), int(ttl)))
            elif var == 'source':
                module_and_args = value.split(":")
                module = module_and_args[0]
                args = module_and_args[1:]
                # Non-empty fromlist makes __import__ return the named
                # (sub)module itself rather than the top-level package.
                source_module = __import__(module, {}, {}, [''])
                config['source'] = source_module.Source(*args)
            elif var == 'filters':
                config['filters'] = []
                for module_and_args_str in value.split():
                    module_and_args = module_and_args_str.split(":")
                    module = module_and_args[0]
                    args = module_and_args[1:]
                    filter_module = __import__(module, {}, {}, [''])
                    config['filters'].append(filter_module.Filter(*args))
            else:
                die("unrecognized paramter in conf file %s: %s\n" % (config_file, var))

        if 'domain' not in config or 'source' not in config:
            die("must specify domain name and source in conf file %s\n" % config_file)
        # Publish the fully built settings in one step.
        config_files[config_file] = config
        sys.stderr.write("read configuration from %s\n" % config_file)

def reread(signum, frame):
    """Signal handler that reloads the configuration files.

    Both arguments are part of the signal-handler calling convention and
    are not used.
    """
    read_config()
    
def die(msg):
    """Write ``msg`` to stderr (no newline is added) and raise SystemExit(-1)."""
    sys.stderr.write(msg)
    raise SystemExit(-1)

def usage(cmd):
    """Report the command-line syntax on stderr and exit via ``die``.

    Args:
        cmd: the program name to show in the message.
    """
    die("Usage: %s [-p port] [-h host] [-u user] [conf file ...]\n" % cmd)

# ---- Script entry: settings, option parsing, socket setup, serve loop ----

# Maps each configuration file name to its settings dict; the dicts are
# filled in by read_config() and consulted by serve().
config_files = {}
# Defaults, overridable with -p (port), -h (host) and -u (user) below.
listen_port = 53
listen_host = ''
runas_user = ''
# NOTE(review): runas_group is never set from an option nor read in the code
# shown here; the group used when dropping privileges is the user's own.
runas_group = ''

try:
    # -p PORT, -h HOST, -u USER; remaining arguments are config file names.
    options, filenames = getopt.getopt(sys.argv[1:], "p:h:u:")
except getopt.GetoptError:
    usage(sys.argv[0])

for option, value in options:
    if option == "-p":
        listen_port = int(value)
    elif option == "-h":
        listen_host = value
    elif option == "-u":
        runas_user = value
# With no file given, fall back to pymds.conf in the current directory.
if not filenames:
    filenames = ['pymds.conf']
for f in filenames:
    if f in config_files:
        raise Exception("repeated configuration")
    config_files[f] = {}

sys.stdout.write("%s starting on port %d\n" % (sys.argv[0], listen_port))
read_config()
# SIGHUP triggers a reload of all configuration files.
signal.signal(signal.SIGHUP, reread)
for config in config_files.values():
    sys.stdout.write("%s: serving for domain %s\n" % (sys.argv[0], ".".join(config['domain'])))
sys.stdout.flush()
sys.stderr.flush()

# IPv6 datagram socket with IPV6_V6ONLY cleared, so that (where the platform
# supports it) the same socket also receives IPv4 queries.
udps = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
udps.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY,0)
udps.bind((listen_host, listen_port))

# Drop privileges only after the socket has been bound.  Order matters:
# group id first, then supplementary groups, and the user id last.
# pwd entry fields used: [0] login name, [2] uid, [3] gid.
if runas_user != '':
    import pwd
    runas_user_data = pwd.getpwnam(runas_user)
    os.setgid(runas_user_data[3])
    # Use os.initgroups() if it's available
    # (older Pythons lack it; the supplementary groups are then left as is).
    try:
        os.initgroups(runas_user_data[0], runas_user_data[3])
    except AttributeError:
        pass
    os.setuid(runas_user_data[2])

# Enter the request loop; this call does not return.
serve(udps)
